{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型保存"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n",
      "sys.version_info(major=3, minor=6, micro=10, releaselevel='final', serial=0)\n",
      "matplotlib 3.1.2\n",
      "numpy 1.18.1\n",
      "pandas 0.25.3\n",
      "sklearn 0.22.1\n",
      "tensorflow 2.0.0\n",
      "tensorflow_core.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "for module in mpl,np,pd,sklearn,tf,keras:\n",
    "    print(module.__name__,module.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据读取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 28, 28) (5000,)\n",
      "(55000, 28, 28) (55000,)\n",
      "(10000, 28, 28) (10000,)\n"
     ]
    }
   ],
   "source": [
    "fashion_mnist = keras.datasets.fashion_mnist\n",
    "(x_train_all,y_train_all),(x_test,y_test) = fashion_mnist.load_data()\n",
    "x_valid,x_train = x_train_all[:5000],x_train_all[5000:]\n",
    "y_valid,y_train = y_train_all[:5000],y_train_all[5000:]\n",
    "\n",
    "print(x_valid.shape,y_valid.shape)\n",
    "print(x_train.shape,y_train.shape)\n",
    "print(x_test.shape,y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "scaler = StandardScaler()\n",
    "x_train_scaled = scaler.fit_transform(\n",
    "    x_train.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)\n",
    "x_valid_scaled = scaler.transform(\n",
    "    x_valid.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)\n",
    "x_test_scaled = scaler.transform(\n",
    "    x_test.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Keras模型保存--保存参数或保存模型结构 + 参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "flatten (Flatten)            (None, 784)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 300)               235500    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 100)               30100     \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 10)                1010      \n",
      "=================================================================\n",
      "Total params: 266,610\n",
      "Trainable params: 266,610\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = keras.models.Sequential([\n",
    "    keras.layers.Flatten(input_shape=[28, 28]),\n",
    "    keras.layers.Dense(300, activation='relu'),\n",
    "    keras.layers.Dense(100, activation='relu'),\n",
    "    keras.layers.Dense(10, activation='softmax')\n",
    "])\n",
    "model.summary()\n",
    "\n",
    "# 模型编译，固化模型\n",
    "model.compile(loss='sparse_categorical_crossentropy',\n",
    "              optimizer='sgd',\n",
    "              metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1、保存模型结构 + 参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 55000 samples, validate on 5000 samples\n",
      "Epoch 1/10\n",
      "55000/55000 [==============================] - 7s 133us/sample - loss: 0.5436 - accuracy: 0.8077 - val_loss: 0.4384 - val_accuracy: 0.8344\n",
      "Epoch 2/10\n",
      "55000/55000 [==============================] - 6s 111us/sample - loss: 0.3931 - accuracy: 0.8585 - val_loss: 0.3657 - val_accuracy: 0.8726\n",
      "Epoch 3/10\n",
      "55000/55000 [==============================] - 6s 110us/sample - loss: 0.3552 - accuracy: 0.8710 - val_loss: 0.3435 - val_accuracy: 0.8770\n",
      "Epoch 4/10\n",
      "55000/55000 [==============================] - 6s 110us/sample - loss: 0.3299 - accuracy: 0.8804 - val_loss: 0.3451 - val_accuracy: 0.8758\n",
      "Epoch 5/10\n",
      "55000/55000 [==============================] - 6s 108us/sample - loss: 0.3109 - accuracy: 0.8867 - val_loss: 0.3318 - val_accuracy: 0.8800\n",
      "Epoch 6/10\n",
      "55000/55000 [==============================] - 6s 109us/sample - loss: 0.2933 - accuracy: 0.8933 - val_loss: 0.3184 - val_accuracy: 0.8838\n",
      "Epoch 7/10\n",
      "55000/55000 [==============================] - 6s 108us/sample - loss: 0.2797 - accuracy: 0.8977 - val_loss: 0.3114 - val_accuracy: 0.8884\n",
      "Epoch 8/10\n",
      "55000/55000 [==============================] - 6s 110us/sample - loss: 0.2686 - accuracy: 0.9015 - val_loss: 0.3080 - val_accuracy: 0.8866\n",
      "Epoch 9/10\n",
      "55000/55000 [==============================] - 6s 109us/sample - loss: 0.2556 - accuracy: 0.9071 - val_loss: 0.3121 - val_accuracy: 0.8858\n",
      "Epoch 10/10\n",
      "55000/55000 [==============================] - 6s 110us/sample - loss: 0.2460 - accuracy: 0.9111 - val_loss: 0.3072 - val_accuracy: 0.8900\n"
     ]
    }
   ],
   "source": [
    "# 定义文件夹和文件\n",
    "logdir = os.path.join('graph_def_and_weights')\n",
    "if not os.path.exists(logdir):\n",
    "    os.mkdir(logdir)\n",
    "output_model_file = os.path.join(logdir, 'fashion_mnist_model.h5')\n",
    "\n",
    "# 定义回调函数\n",
    "callbacks = [\n",
    "    keras.callbacks.TensorBoard(log_dir = logdir, profile_batch = 100000000),\n",
    "    # 定义保存模型\n",
    "    keras.callbacks.ModelCheckpoint(\n",
    "        output_model_file, save_best_only = True, save_weights_only = False),\n",
    "    keras.callbacks.EarlyStopping(patience=5,min_delta=1e-3),\n",
    "]\n",
    "history = model.fit(x_train_scaled, y_train, epochs=10,\n",
    "                    validation_data=(x_valid_scaled, y_valid),\n",
    "                    callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 载入模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "loaded_model = keras.models.load_model(output_model_file)\n",
    "loaded_model.evaluate(x_test_scaled, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2、只保存参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义文件夹和文件\n",
    "logdir = os.path.join('graph_def_and_weights')\n",
    "if not os.path.exists(logdir):\n",
    "    os.mkdir(logdir)\n",
    "output_model_file = os.path.join(logdir, 'fashion_mnist_weights.h5')\n",
    "\n",
    "# 定义回调函数\n",
    "callbacks = [\n",
    "    keras.callbacks.TensorBoard(log_dir = logdir, profile_batch = 100000000),\n",
    "    # 定义保存模型\n",
    "    keras.callbacks.ModelCheckpoint(\n",
    "        output_model_file, save_best_only = True, save_weights_only = True),\n",
    "    keras.callbacks.EarlyStopping(patience=5,min_delta=1e-3),\n",
    "]\n",
    "history = model.fit(x_train_scaled, y_train, epochs=10,\n",
    "                    validation_data=(x_valid_scaled, y_valid),\n",
    "                    callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 加载参数\n",
    "在另一个文件中使用时"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logdir = os.path.join('graph_def_and_weights')\n",
    "if not os.path.exists(logdir):\n",
    "    os.mkdir(logdir)\n",
    "output_model_file = os.path.join(logdir, 'fashion_mnist_weights.h5')\n",
    "\n",
    "model.load_weights(output_model_file)\n",
    "model.evaluate(x_test_scaled, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3、另一种方法只保存参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logdir = os.path.join('graph_def_and_weights')\n",
    "if not os.path.exists(logdir):\n",
    "    os.mkdir(logdir)\n",
    "\n",
    "history = model.fit(x_train_scaled, y_train, epochs=10,\n",
    "                    validation_data=(x_valid_scaled, y_valid))\n",
    "\n",
    "model.save_weights(os.path.join(logdir, 'fashion_mnist_weights_2.h5'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 加载参数\n",
    "在另一个文件中使用时"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logdir = os.path.join('graph_def_and_weights')\n",
    "if not os.path.exists(logdir):\n",
    "    os.mkdir(logdir)\n",
    "output_model_file = os.path.join(logdir, 'fashion_mnist_weights_2.h5')\n",
    "\n",
    "model.load_weights(output_model_file)\n",
    "model.evaluate(x_test_scaled, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Keras模型转化为SavedModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "flatten_1 (Flatten)          (None, 784)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 300)               235500    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 100)               30100     \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 10)                1010      \n",
      "=================================================================\n",
      "Total params: 266,610\n",
      "Trainable params: 266,610\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = keras.models.Sequential([\n",
    "    keras.layers.Flatten(input_shape=[28, 28]),\n",
    "    keras.layers.Dense(300, activation='relu'),\n",
    "    keras.layers.Dense(100, activation='relu'),\n",
    "    keras.layers.Dense(10, activation='softmax')\n",
    "])\n",
    "model.summary()\n",
    "\n",
    "# 模型编译，固化模型\n",
    "model.compile(loss='sparse_categorical_crossentropy',\n",
    "              optimizer='sgd',\n",
    "              metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 55000 samples, validate on 5000 samples\n",
      "Epoch 1/10\n",
      "55000/55000 [==============================] - 7s 126us/sample - loss: 0.5293 - accuracy: 0.8130 - val_loss: 0.4131 - val_accuracy: 0.8542\n",
      "Epoch 2/10\n",
      "55000/55000 [==============================] - 6s 108us/sample - loss: 0.3923 - accuracy: 0.8595 - val_loss: 0.3637 - val_accuracy: 0.8720\n",
      "Epoch 3/10\n",
      "55000/55000 [==============================] - 6s 112us/sample - loss: 0.3530 - accuracy: 0.8738 - val_loss: 0.3747 - val_accuracy: 0.8614\n",
      "Epoch 4/10\n",
      "55000/55000 [==============================] - 6s 110us/sample - loss: 0.3291 - accuracy: 0.8823 - val_loss: 0.3323 - val_accuracy: 0.8774\n",
      "Epoch 5/10\n",
      "55000/55000 [==============================] - 6s 118us/sample - loss: 0.3096 - accuracy: 0.8880 - val_loss: 0.3218 - val_accuracy: 0.8878\n",
      "Epoch 6/10\n",
      "55000/55000 [==============================] - 6s 115us/sample - loss: 0.2934 - accuracy: 0.8938 - val_loss: 0.3150 - val_accuracy: 0.8856\n",
      "Epoch 7/10\n",
      "55000/55000 [==============================] - 6s 116us/sample - loss: 0.2794 - accuracy: 0.8982 - val_loss: 0.3153 - val_accuracy: 0.8864\n",
      "Epoch 8/10\n",
      "55000/55000 [==============================] - 6s 109us/sample - loss: 0.2684 - accuracy: 0.9020 - val_loss: 0.2957 - val_accuracy: 0.8956\n",
      "Epoch 9/10\n",
      "55000/55000 [==============================] - 6s 110us/sample - loss: 0.2559 - accuracy: 0.9075 - val_loss: 0.2985 - val_accuracy: 0.8906\n",
      "Epoch 10/10\n",
      "55000/55000 [==============================] - 6s 109us/sample - loss: 0.2469 - accuracy: 0.9109 - val_loss: 0.2954 - val_accuracy: 0.8944\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(x_train_scaled, y_train, epochs=10,\n",
    "                    validation_data=(x_valid_scaled, y_valid))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 保存模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: keras_saved_graph\\assets\n"
     ]
    }
   ],
   "source": [
    "logdir = os.path.join('keras_saved_graph')\n",
    "if not os.path.exists(logdir):\n",
    "    os.mkdir(logdir)\n",
    "\n",
    "tf.saved_model.save(model, logdir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用工具查看和测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:\n",
      "\n",
      "signature_def['__saved_model_init_op']:\n",
      "  The given SavedModel SignatureDef contains the following input(s):\n",
      "  The given SavedModel SignatureDef contains the following output(s):\n",
      "    outputs['__saved_model_init_op'] tensor_info:\n",
      "        dtype: DT_INVALID\n",
      "        shape: unknown_rank\n",
      "        name: NoOp\n",
      "  Method name is: \n",
      "\n",
      "signature_def['serving_default']:\n",
      "  The given SavedModel SignatureDef contains the following input(s):\n",
      "    inputs['flatten_1_input'] tensor_info:\n",
      "        dtype: DT_FLOAT\n",
      "        shape: (-1, 28, 28)\n",
      "        name: serving_default_flatten_1_input:0\n",
      "  The given SavedModel SignatureDef contains the following output(s):\n",
      "    outputs['dense_2'] tensor_info:\n",
      "        dtype: DT_FLOAT\n",
      "        shape: (-1, 10)\n",
      "        name: StatefulPartitionedCall:0\n",
      "  Method name is: tensorflow/serving/predict\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-10 12:35:17.031604: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll\n"
     ]
    }
   ],
   "source": [
    "!saved_model_cli show --dir ./keras_saved_graph --all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The given SavedModel SignatureDef contains the following input(s):\n",
      "  inputs['flatten_1_input'] tensor_info:\n",
      "      dtype: DT_FLOAT\n",
      "      shape: (-1, 28, 28)\n",
      "      name: serving_default_flatten_1_input:0\n",
      "The given SavedModel SignatureDef contains the following output(s):\n",
      "  outputs['dense_2'] tensor_info:\n",
      "      dtype: DT_FLOAT\n",
      "      shape: (-1, 10)\n",
      "      name: StatefulPartitionedCall:0\n",
      "Method name is: tensorflow/serving/predict\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-10 12:43:58.013512: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll\n"
     ]
    }
   ],
   "source": [
    "!saved_model_cli show --dir ./keras_saved_graph --tag_set serve --signature_def serving_default"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result for output key dense_2:\n",
      "[[0.07866267 0.04084459 0.0750244  0.03414746 0.09606657 0.01821422\n",
      "  0.09031539 0.03132759 0.52319443 0.0122027 ]\n",
      " [0.07866267 0.04084459 0.0750244  0.03414746 0.09606657 0.01821422\n",
      "  0.09031539 0.03132759 0.52319443 0.0122027 ]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-10 12:52:31.779954: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll\n",
      "2020-02-10 12:52:34.206038: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library nvcuda.dll\n",
      "2020-02-10 12:52:34.773939: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1618] Found device 0 with properties: \n",
      "name: GeForce GTX 850M major: 5 minor: 0 memoryClockRate(GHz): 0.9015\n",
      "pciBusID: 0000:0a:00.0\n",
      "2020-02-10 12:52:34.774145: I tensorflow/stream_executor/platform/default/dlopen_checker_stub.cc:25] GPU libraries are statically linked, skip dlopen check.\n",
      "2020-02-10 12:52:34.775123: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1746] Adding visible gpu devices: 0\n",
      "2020-02-10 12:52:34.775552: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2\n",
      "2020-02-10 12:52:34.778724: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1618] Found device 0 with properties: \n",
      "name: GeForce GTX 850M major: 5 minor: 0 memoryClockRate(GHz): 0.9015\n",
      "pciBusID: 0000:0a:00.0\n",
      "2020-02-10 12:52:34.778981: I tensorflow/stream_executor/platform/default/dlopen_checker_stub.cc:25] GPU libraries are statically linked, skip dlopen check.\n",
      "2020-02-10 12:52:34.779989: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1746] Adding visible gpu devices: 0\n",
      "2020-02-10 12:52:36.218809: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1159] Device interconnect StreamExecutor with strength 1 edge matrix:\n",
      "2020-02-10 12:52:36.218967: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1165]      0 \n",
      "2020-02-10 12:52:36.219056: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1178] 0:   N \n",
      "2020-02-10 12:52:36.220289: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1304] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 3040 MB memory) -> physical GPU (device: 0, name: GeForce GTX 850M, pci bus id: 0000:0a:00.0, compute capability: 5.0)\n",
      "WARNING:tensorflow:From f:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\tools\\saved_model_cli.py:339: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.\n",
      "2020-02-10 12:52:36.663700: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cublas64_100.dll\n"
     ]
    }
   ],
   "source": [
    "!saved_model_cli run --dir ./keras_saved_graph --tag_set serve --signature_def serving_default \\\n",
    "    --input_exprs \"flatten_1_input=np.ones((2, 28, 28))\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用程序进行验证和推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['serving_default']\n"
     ]
    }
   ],
   "source": [
    "# 加载模型\n",
    "loaded_saved_model = tf.saved_model.load('./keras_saved_graph')\n",
    "# 列出所有签名函数\n",
    "print(list(loaded_saved_model.signatures.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<tensorflow.python.saved_model.load._WrapperFunction object at 0x0000025D0A4A4DA0>\n"
     ]
    }
   ],
   "source": [
    "# 利用签名获得函数句柄\n",
    "inference = loaded_saved_model.signatures['serving_default']\n",
    "print(inference)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'dense_2': TensorSpec(shape=(None, 10), dtype=tf.float32, name='dense_2')}\n"
     ]
    }
   ],
   "source": [
    "# 打印输出结构\n",
    "print(inference.structured_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'dense_2': <tf.Tensor: id=622, shape=(1, 10), dtype=float32, numpy=\n",
      "array([[1.8168412e-06, 8.1950509e-07, 9.4149618e-06, 7.7796500e-07,\n",
      "        1.4993745e-05, 1.9446755e-02, 2.9783282e-06, 3.8230207e-02,\n",
      "        2.2591450e-04, 9.4206625e-01]], dtype=float32)>}\n",
      "tf.Tensor(\n",
      "[[1.8168412e-06 8.1950509e-07 9.4149618e-06 7.7796500e-07 1.4993745e-05\n",
      "  1.9446755e-02 2.9783282e-06 3.8230207e-02 2.2591450e-04 9.4206625e-01]], shape=(1, 10), dtype=float32)\n",
      "[[1.8168412e-06 8.1950509e-07 9.4149618e-06 7.7796500e-07 1.4993745e-05\n",
      "  1.9446755e-02 2.9783282e-06 3.8230207e-02 2.2591450e-04 9.4206625e-01]]\n"
     ]
    }
   ],
   "source": [
    "# 使用inference进行推理\n",
    "results = inference(tf.constant(x_test_scaled[0:1]))\n",
    "print(results)\n",
    "print(results['dense_2'])\n",
    "print(results['dense_2'].numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
